{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training and Tracking Machine Learning Models in Java with Tribuo\n",
    "\n",
    "This notebook is a companion to the JavaOne 2022 talk \"Training and Tracking Machine Learning Models in Java with Tribuo\". It covers how to train, evaluate and examine models in [Tribuo](https://tribuo.org) an open-source Java ML library. Tribuo is Apache 2.0 licensed and available on GitHub (https://github.com/oracle/tribuo).\n",
    "\n",
    "In this notebook we'll build a simple model, evaluate its performance, and examine the provenance information collected about the training and evaluation procedures. We'll then export the Tribuo model in ONNX format, before importing a model trained in scikit-learn to do the same task then evaluating both the exported Tribuo model and the imported scikit-learn model.\n",
    "\n",
    "## Setup\n",
    "\n",
    "First we need to load in some jars and import the relevant packages from Tribuo and the JDK. The jars can be built from the Tribuo source tree with `mvn package`. We'll also need MNIST in IDX format if you don't already have that.\n",
    "\n",
    "MNIST training set:\n",
    "\n",
    "`wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz`\n",
    "\n",
    "`wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz`\n",
    "\n",
    "MNIST test set:\n",
    "\n",
    "`wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz`\n",
    "\n",
    "`wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%jars tribuo-classification-experiments-4.3.0-jar-with-dependencies.jar\n",
    "%jars tribuo-json-4.3.0-jar-with-dependencies.jar\n",
    "%jars tribuo-onnx-4.3.0-jar-with-dependencies.jar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.nio.file.*;\n",
    "\n",
    "import ai.onnxruntime.*;\n",
    "import com.fasterxml.jackson.databind.*;\n",
    "import com.oracle.labs.mlrg.olcut.provenance.primitives.*;\n",
    "import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;\n",
    "import com.oracle.labs.mlrg.olcut.util.Pair;\n",
    "\n",
    "import org.tribuo.*;\n",
    "import org.tribuo.classification.*;\n",
    "import org.tribuo.classification.evaluation.*;\n",
    "import org.tribuo.classification.sgd.linear.*;\n",
    "import org.tribuo.classification.sgd.objectives.LogMulticlass;\n",
    "import org.tribuo.datasource.IDXDataSource;\n",
    "import org.tribuo.math.optimisers.*;\n",
    "import org.tribuo.interop.onnx.*;\n",
    "import org.tribuo.util.Util;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training models with Tribuo\n",
    "\n",
    "Before we can train a model we need to load some data, we'll pull in MNIST in the original IDX format using the built in `IDXDataSource`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST train size = 60000, number of features = 717, number of classes = 10\n",
      "MNIST test size = 10000, number of features = 668, number of classes = 10\n"
     ]
    }
   ],
   "source": [
    "var labelFactory = new LabelFactory();\n",
    "var labelEvaluator = new LabelEvaluator();\n",
    "var mnistTrainSource = new IDXDataSource<>(Paths.get(\"train-images-idx3-ubyte.gz\"),Paths.get(\"train-labels-idx1-ubyte.gz\"),labelFactory);\n",
    "var mnistTestSource = new IDXDataSource<>(Paths.get(\"t10k-images-idx3-ubyte.gz\"),Paths.get(\"t10k-labels-idx1-ubyte.gz\"),labelFactory);\n",
    "var mnistTrain = new MutableDataset<>(mnistTrainSource);\n",
    "var mnistTest = new MutableDataset<>(mnistTestSource);\n",
    "\n",
    "\n",
    "System.out.println(String.format(\"MNIST train size = %d, number of features = %d, number of classes = %d\",mnistTrain.size(),mnistTrain.getFeatureMap().size(),mnistTrain.getOutputInfo().size()));\n",
    "System.out.println(String.format(\"MNIST test size = %d, number of features = %d, number of classes = %d\",mnistTest.size(),mnistTest.getFeatureMap().size(),mnistTest.getOutputInfo().size()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MNIST is 28x28 pixel handwritten digits, for a total of 784 possible features, though 67 of them are always blank (and so Tribuo ignores them). A sample of them is given below.\n",
    "\n",
    "<img src=\"mnist-3.0.1.png\" alt=\"MNIST Digits\" style=\"width: 400px;\"/>\n",
    "\n",
    "Now we can define the model we're going to train. We'll use a simple logistic regression, trained using a gradient descent algorithm called AdaGrad."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "var lrTrainer = new LinearSGDTrainer( // training algorithm name\n",
    "                                     new LogMulticlass(),  // Loss function\n",
    "                                     new AdaGrad(0.1),     // Gradient optimiser\n",
    "                                     3,                    // Number of training epochs\n",
    "                                     30000,                // Logging interval\n",
    "                                     Trainer.DEFAULT_SEED  // RNG seed\n",
    "                                    );"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training the model is simple, we supply the training dataset to the trainer's train method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training logistic regression took (00:00:06:652)\n"
     ]
    }
   ],
   "source": [
    "var lrStartTime = System.currentTimeMillis();\n",
    "var lrModel = lrTrainer.train(mnistTrain, Map.of(\"Extra\", new StringProvenance(\"Extra\",\"Some runtime information\")));\n",
    "var lrEndTime = System.currentTimeMillis();\n",
    "\n",
    "System.out.println(\"Training logistic regression took \" + Util.formatDuration(lrStartTime,lrEndTime));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Evaluation is similarly straightforward, we supply the model and test dataset to the classification evaluator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Scoring logistic regression took (00:00:00:366)\n",
      "Evaluation:\n",
      "Class                           n          tp          fn          fp      recall        prec          f1\n",
      "0                             980         944          36          41       0.963       0.958       0.961\n",
      "1                           1,135       1,121          14          75       0.988       0.937       0.962\n",
      "2                           1,032         906         126          88       0.878       0.911       0.894\n",
      "3                           1,010         781         229          44       0.773       0.947       0.851\n",
      "4                             982         888          94          77       0.904       0.920       0.912\n",
      "5                             892         770         122         149       0.863       0.838       0.850\n",
      "6                             958         923          35         111       0.963       0.893       0.927\n",
      "7                           1,028         954          74         112       0.928       0.895       0.911\n",
      "8                             974         853         121         229       0.876       0.788       0.830\n",
      "9                           1,009         852         157          82       0.844       0.912       0.877\n",
      "Total                      10,000       8,992       1,008       1,008\n",
      "Accuracy                                                                    0.899\n",
      "Micro Average                                                               0.899       0.899       0.899\n",
      "Macro Average                                                               0.898       0.900       0.898\n",
      "Balanced Error Rate                                                         0.102\n",
      "\n",
      "Confusion Matrix:\n",
      "               0       1       2       3       4       5       6       7       8       9\n",
      "0            944       0       0       2       1       7      21       4       0       1\n",
      "1              0   1,121       3       2       0       1       4       1       3       0\n",
      "2              5      13     906       6       8       5      16      11      59       3\n",
      "3              6       3      33     781       3      83       7      12      75       7\n",
      "4              0       7       8       0     888       1      25      16       7      30\n",
      "5              9       4       6      18       6     770      21       2      50       6\n",
      "6              4       2       9       0       5       8     923       1       6       0\n",
      "7              0      11      21       4       7       0       0     954       5      26\n",
      "8             11      22       7       8       7      36      15       6     853       9\n",
      "9              6      13       1       4      40       8       2      59      24     852\n",
      "\n"
     ]
    }
   ],
   "source": [
    "lrStartTime = System.currentTimeMillis();\n",
    "var mnistEval = labelEvaluator.evaluate(lrModel,mnistTest);\n",
    "lrEndTime = System.currentTimeMillis();\n",
    "\n",
    "System.out.println(\"Scoring logistic regression took \" + Util.formatDuration(lrStartTime,lrEndTime));\n",
    "System.out.println(\"Evaluation:\\n\"+mnistEval.toString());\n",
    "System.out.println(\"\\nConfusion Matrix:\");\n",
    "System.out.println(mnistEval.getConfusionMatrix().toString());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The evaluation exposes accessors for all the metrics it computes, and displays a subset of them in the `toString()` method in a format suitable for human consumption. There are also methods on the evaluation to print the results in a HTML table. The confusion matrix (which shows how many of each ground truth label were predicted as another label) also has accessors and again has a method to convert it into HTML.\n",
    "\n",
    "Now let's inspect the information Tribuo captured as part of the training procedure. We'll convert it into a moderately human readable JSON format, though there are many options for displaying the provenance and for accessing it programmatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"instance-values\" : {\n",
      "    \"Extra\" : \"Some runtime information\"\n",
      "  },\n",
      "  \"tribuo-version\" : \"4.3.0\",\n",
      "  \"java-version\" : \"19-panama\",\n",
      "  \"trainer\" : {\n",
      "    \"seed\" : \"12345\",\n",
      "    \"tribuo-version\" : \"4.3.0\",\n",
      "    \"minibatchSize\" : \"1\",\n",
      "    \"train-invocation-count\" : \"0\",\n",
      "    \"is-sequence\" : \"false\",\n",
      "    \"shuffle\" : \"true\",\n",
      "    \"epochs\" : \"3\",\n",
      "    \"optimiser\" : {\n",
      "      \"epsilon\" : \"1.0E-6\",\n",
      "      \"initialLearningRate\" : \"0.1\",\n",
      "      \"initialValue\" : \"0.0\",\n",
      "      \"host-short-name\" : \"StochasticGradientOptimiser\",\n",
      "      \"class-name\" : \"org.tribuo.math.optimisers.AdaGrad\"\n",
      "    },\n",
      "    \"host-short-name\" : \"Trainer\",\n",
      "    \"class-name\" : \"org.tribuo.classification.sgd.linear.LinearSGDTrainer\",\n",
      "    \"loggingInterval\" : \"30000\",\n",
      "    \"objective\" : {\n",
      "      \"host-short-name\" : \"LabelObjective\",\n",
      "      \"class-name\" : \"org.tribuo.classification.sgd.objectives.LogMulticlass\"\n",
      "    }\n",
      "  },\n",
      "  \"os-arch\" : \"x86_64\",\n",
      "  \"trained-at\" : \"2022-10-19T10:24:53.429354-07:00\",\n",
      "  \"os-name\" : \"Mac OS X\",\n",
      "  \"dataset\" : {\n",
      "    \"num-features\" : \"717\",\n",
      "    \"num-examples\" : \"60000\",\n",
      "    \"num-outputs\" : \"10\",\n",
      "    \"tribuo-version\" : \"4.3.0\",\n",
      "    \"datasource\" : {\n",
      "      \"outputPath\" : \"/Users/apocock/Development/Tribuo/docs/talks/JavaOne-2022/train-labels-idx1-ubyte.gz\",\n",
      "      \"features-file-modified-time\" : \"2022-10-19T10:15:28.420-07:00\",\n",
      "      \"output-resource-hash\" : \"3552534A0A558BBED6AED32B30C495CCA23D567EC52CAC8BE1A0730E8010255C\",\n",
      "      \"datasource-creation-time\" : \"2022-10-19T10:24:45.560608-07:00\",\n",
      "      \"output-file-modified-time\" : \"2022-10-19T10:15:28.457-07:00\",\n",
      "      \"idx-feature-type\" : \"UBYTE\",\n",
      "      \"outputFactory\" : {\n",
      "        \"class-name\" : \"org.tribuo.classification.LabelFactory\"\n",
      "      },\n",
      "      \"features-resource-hash\" : \"440FCABF73CC546FA21475E81EA370265605F56BE210A4024D2CA8F203523609\",\n",
      "      \"featuresPath\" : \"/Users/apocock/Development/Tribuo/docs/talks/JavaOne-2022/train-images-idx3-ubyte.gz\",\n",
      "      \"host-short-name\" : \"DataSource\",\n",
      "      \"class-name\" : \"org.tribuo.datasource.IDXDataSource\"\n",
      "    },\n",
      "    \"transformations\" : [ ],\n",
      "    \"is-sequence\" : \"false\",\n",
      "    \"is-dense\" : \"false\",\n",
      "    \"class-name\" : \"org.tribuo.MutableDataset\"\n",
      "  },\n",
      "  \"class-name\" : \"org.tribuo.classification.sgd.linear.LinearSGDModel\"\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "ObjectMapper objMapper = new ObjectMapper();\n",
    "objMapper = objMapper.enable(SerializationFeature.INDENT_OUTPUT);\n",
    "\n",
    "String jsonEvaluationProvenance = objMapper.writeValueAsString(ProvenanceUtil.convertToMap(lrModel.getProvenance()));\n",
    "System.out.println(jsonEvaluationProvenance);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The provenance has captured the training algorithm and it's parameters, along with the dataset including paths to the various files, timestamps and hashes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Working with external models\n",
    "\n",
    "Unfortunately people don't do all their machine learning in Java, so we need to interact with external systems. In Tribuo we have support for loading models trained in other systems using TensorFlow and XGBoost, along with any model that can be exported into the [ONNX format](https://onnx.ai).\n",
    "\n",
    "We can also export many Tribuo models in ONNX format for use elsewhere.\n",
    "\n",
    "We're going to import an ONNX model, compare it to the Tribuo model we just trained, and finally export our Tribuo model in ONNX so it can be deployed in other systems.\n",
    "\n",
    "Much of the complexity in working with external models is ensuring that Tribuo's view of the input data matches the external model's view. This means that we need to ensure each named feature in Tribuo is assigned the index that the external model expects. Tribuo's features are named, so if we have a `price_per_unit` feature in Tribuo, then we need to make sure that it is presented to the external model using the right index e.g., `9`. Similarly the output dimensions from a Tribuo model are named, and so we need to map from the indices of the output tensor to Tribuo's output dimension names.\n",
    "\n",
    "We're going to demonstrate this mapping by using a model trained in scikit-learn on MNIST, where the features are given the opposite index (e.g., feature `783` is supplied as feature `0` and feature `0` is supplied as feature `783`).\n",
    "\n",
    "First we load in some of the ONNX support classes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "var denseTransformer = new DenseTransformer();\n",
    "var labelTransformer = new LabelTransformer();\n",
    "var onnxSklPath = Paths.get(\"..\",\"..\",\"..\",\"tutorials\",\"external-models\",\"skl_lr_mnist.onnx\");\n",
    "var ortEnv = OrtEnvironment.getEnvironment();\n",
    "var sessionOpts = new OrtSession.SessionOptions();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we compute the inverted mapping in Tribuo, and load in the scikit-learn model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "Map<String, Integer> sklFeatMapping = new HashMap<>();\n",
    "for (int i = 0; i < 784; i++) {\n",
    "    // This MNIST model has the feature indices transposed to test a non-trivial mapping.\n",
    "    int id = (783 - i);\n",
    "    sklFeatMapping.put(String.format(\"%03d\", i), id);\n",
    "}\n",
    "Map<Label, Integer> sklOutMapping = new HashMap<>();\n",
    "for (Label l : mnistTrain.getOutputInfo().getDomain()) {\n",
    "    sklOutMapping.put(l, Integer.parseInt(l.getLabel()));\n",
    "}\n",
    "Model<Label> sklModel = ONNXExternalModel.createOnnxModel(labelFactory, sklFeatMapping, sklOutMapping, \n",
    "                    denseTransformer, labelTransformer, sessionOpts, onnxSklPath, \"float_input\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `sklModel` is a Tribuo model like any other, so we can evaluate it in the same way we evaluated the logistic regression earlier:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Class                           n          tp          fn          fp      recall        prec          f1\n",
       "0                             980         963          17          46       0.983       0.954       0.968\n",
       "1                           1,135       1,112          23          37       0.980       0.968       0.974\n",
       "2                           1,032         926         106          70       0.897       0.930       0.913\n",
       "3                           1,010         916          94          98       0.907       0.903       0.905\n",
       "4                             982         910          72          64       0.927       0.934       0.930\n",
       "5                             892         776         116          83       0.870       0.903       0.886\n",
       "6                             958         910          48          55       0.950       0.943       0.946\n",
       "7                           1,028         951          77          70       0.925       0.931       0.928\n",
       "8                             974         869         105         133       0.892       0.867       0.880\n",
       "9                           1,009         922          87          89       0.914       0.912       0.913\n",
       "Total                      10,000       9,255         745         745\n",
       "Accuracy                                                                    0.926\n",
       "Micro Average                                                               0.926       0.926       0.926\n",
       "Macro Average                                                               0.924       0.925       0.924\n",
       "Balanced Error Rate                                                         0.076"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "var sklEvaluation = labelEvaluator.evaluate(sklModel,mnistTest);\n",
    "sklEvaluation.toString();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Due to differences in the default hyperparameters, the scikit-learn model performs a little differently to the one trained in Tribuo, such is the way of ML systems, they are hard to replicate across platforms. With a bit of tweaking to the Tribuo model's hyperparameters we could get the same performance, but that's an exercise for another time. \n",
    "\n",
    "Now lets export the Tribuo model into ONNX format and load it back in using the same external model support."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "var lrModelPath = Paths.get(\".\",\"javaone-lr-mnist.onnx\");\n",
    "lrModel.saveONNXModel(\"org.tribuo.javaone.lr\", // namespace for the model\n",
    "                      0,                       // model version number\n",
    "                      lrModelPath              // path to save the model\n",
    "                      );"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again we'll need to define the mapping between feature names and indices for our ONNX model, but fortunately Tribuo builds all that information anyway, so we simply copy it into the right data structures:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "Map<String, Integer> mnistFeatureMap = new HashMap<>();\n",
    "for (VariableInfo f : lrModel.getFeatureIDMap()){\n",
    "    VariableIDInfo id = (VariableIDInfo) f;\n",
    "    mnistFeatureMap.put(id.getName(),id.getID());\n",
    "}\n",
    "Map<Label, Integer> mnistOutputMap = new HashMap<>();\n",
    "for (Pair<Integer,Label> l : lrModel.getOutputIDInfo()) {\n",
    "    mnistOutputMap.put(l.getB(), l.getA());\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Loading the model is the same as before, just with the correct mappings and path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "ONNXExternalModel<Label> onnxLR = ONNXExternalModel.createOnnxModel(labelFactory, mnistFeatureMap, mnistOutputMap,\n",
    "                    denseTransformer, labelTransformer, sessionOpts, lrModelPath, \"input\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And we can evaluate it the same way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Class                           n          tp          fn          fp      recall        prec          f1\n",
       "0                             980         944          36          41       0.963       0.958       0.961\n",
       "1                           1,135       1,121          14          75       0.988       0.937       0.962\n",
       "2                           1,032         906         126          88       0.878       0.911       0.894\n",
       "3                           1,010         781         229          44       0.773       0.947       0.851\n",
       "4                             982         888          94          77       0.904       0.920       0.912\n",
       "5                             892         770         122         149       0.863       0.838       0.850\n",
       "6                             958         923          35         111       0.963       0.893       0.927\n",
       "7                           1,028         954          74         112       0.928       0.895       0.911\n",
       "8                             974         853         121         229       0.876       0.788       0.830\n",
       "9                           1,009         852         157          82       0.844       0.912       0.877\n",
       "Total                      10,000       8,992       1,008       1,008\n",
       "Accuracy                                                                    0.899\n",
       "Micro Average                                                               0.899       0.899       0.899\n",
       "Macro Average                                                               0.898       0.900       0.898\n",
       "Balanced Error Rate                                                         0.102"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "var onnxEvaluation = labelEvaluator.evaluate(onnxLR,mnistTest);\n",
    "onnxEvaluation.toString();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "We saw how to train & evalutate models in Tribuo, inspected the provenance captured by the training procedure, and then discussed how to import and export models from Tribuo."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "19-panama+1-13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
